{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5c4d70b6",
   "metadata": {},
   "source": [
    "# Speculative Prompt Caching\n",
    "\n",
    "This cookbook demonstrates \"Speculative Prompt Caching\" - a pattern that reduces time-to-first-token (TTFT) by warming up the cache while users are still formulating their queries.\n",
    "\n",
    "**Without Speculative Caching:**\n",
    "1. User types their question (3 seconds)\n",
    "2. User submits question\n",
    "3. API loads context into cache AND generates response\n",
    "\n",
    "**With Speculative Caching:**\n",
    "1. User starts typing (cache warming begins immediately)\n",
    "2. User continues typing (cache warming completes in background)\n",
    "3. User submits question\n",
    "4. API uses warm cache to generate response"
   ]
  },
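  {
   "cell_type": "markdown",
   "id": "a3e19c04",
   "metadata": {},
   "source": [
    "To make the two timelines concrete, the next cell gives a minimal sketch of how long a user waits between pressing submit and seeing the first token. It is a back-of-the-envelope model rather than a measurement: the function names and durations are illustrative assumptions, and cache warming is treated as taking a fixed amount of time. Warming is fully hidden only when the user types for at least as long as the warm-up takes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3e19c05",
   "metadata": {},
   "outputs": [],
   "source": [
    "def wait_without_warming(cold_ttft: float) -> float:\n",
    "    \"\"\"Seconds from submit to first token when the cache is cold.\"\"\"\n",
    "    return cold_ttft\n",
    "\n",
    "\n",
    "def wait_with_warming(typing_time: float, warm_up_time: float, warm_ttft: float) -> float:\n",
    "    \"\"\"Seconds from submit to first token when warming starts as typing starts.\n",
    "\n",
    "    Any part of the warm-up that outlasts the typing time is still paid\n",
    "    after submit; the rest is hidden behind the user's typing.\n",
    "    \"\"\"\n",
    "    remaining_warm_up = max(0.0, warm_up_time - typing_time)\n",
    "    return remaining_warm_up + warm_ttft\n",
    "\n",
    "\n",
    "# Hypothetical durations in seconds, chosen only for illustration.\n",
    "print(wait_without_warming(cold_ttft=10.0))\n",
    "print(wait_with_warming(typing_time=3.0, warm_up_time=9.0, warm_ttft=1.0))\n",
    "print(wait_with_warming(typing_time=12.0, warm_up_time=9.0, warm_ttft=1.0))"
   ]
  },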
  {
   "cell_type": "markdown",
   "id": "66720f9a",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "First, let's install the required packages:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "id": "9a0adb31",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "%pip install anthropic httpx --quiet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "id": "c4a98035",
   "metadata": {},
   "outputs": [],
   "source": [
    "import asyncio\n",
    "import copy\n",
    "import datetime\n",
    "import time\n",
    "\n",
    "import httpx\n",
    "from anthropic import AsyncAnthropic\n",
    "\n",
    "# Configuration constants\n",
    "MODEL = \"claude-sonnet-4-5\"\n",
    "SQLITE_SOURCES = {\n",
    "    \"btree.h\": \"https://sqlite.org/src/raw/18e5e7b2124c23426a283523e5f31a4bff029131b795bb82391f9d2f3136fc50?at=btree.h\",\n",
    "    \"btree.c\": \"https://sqlite.org/src/raw/63ca6b647342e8cef643863cd0962a542f133e1069460725ba4461dcda92b03c?at=btree.c\",\n",
    "}\n",
    "DEFAULT_CLIENT_ARGS = {\n",
    "    \"system\": \"You are an expert systems programmer helping analyze database internals.\",\n",
    "    \"max_tokens\": 4096,\n",
    "    \"temperature\": 0,\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "49d56e2f",
   "metadata": {},
   "source": [
    "## Helper Functions\n",
    "\n",
    "Let's set up the functions to download our large context and prepare messages:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "id": "08b7a5a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "async def get_sqlite_sources() -> dict[str, str]:\n",
    "    print(\"Downloading SQLite source files...\")\n",
    "\n",
    "    source_files = {}\n",
    "    start_time = time.time()\n",
    "\n",
    "    async with httpx.AsyncClient(timeout=30.0) as client:\n",
    "        tasks = []\n",
    "\n",
    "        async def download_file(filename: str, url: str) -> tuple[str, str]:\n",
    "            response = await client.get(url, follow_redirects=True)\n",
    "            response.raise_for_status()\n",
    "            print(f\"Successfully downloaded {filename}\")\n",
    "            return filename, response.text\n",
    "\n",
    "        for filename, url in SQLITE_SOURCES.items():\n",
    "            tasks.append(download_file(filename, url))\n",
    "\n",
    "        results = await asyncio.gather(*tasks)\n",
    "        source_files = dict(results)\n",
    "\n",
    "    duration = time.time() - start_time\n",
    "    print(f\"Downloaded {len(source_files)} files in {duration:.2f} seconds\")\n",
    "    return source_files\n",
    "\n",
    "\n",
    "async def create_initial_message():\n",
    "    sources = await get_sqlite_sources()\n",
    "    # Prepare the initial message with the source code as context.\n",
    "    # A Timestamp is included to prevent cache sharing across different runs.\n",
    "    initial_message = {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": [\n",
    "            {\n",
    "                \"type\": \"text\",\n",
    "                \"text\": f\"\"\"\n",
    "Current time: {datetime.datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")}\n",
    "\n",
    "Source to Analyze:\n",
    "\n",
    "btree.h:\n",
    "```c\n",
    "{sources[\"btree.h\"]}\n",
    "```\n",
    "\n",
    "btree.c:\n",
    "```c\n",
    "{sources[\"btree.c\"]}\n",
    "```\"\"\",\n",
    "                \"cache_control\": {\"type\": \"ephemeral\"},\n",
    "            }\n",
    "        ],\n",
    "    }\n",
    "    return initial_message\n",
    "\n",
    "\n",
    "async def sample_one_token(client: AsyncAnthropic, messages: list):\n",
    "    \"\"\"Send a single-token request to warm up the cache\"\"\"\n",
    "    args = copy.deepcopy(DEFAULT_CLIENT_ARGS)\n",
    "    args[\"max_tokens\"] = 1\n",
    "    await client.messages.create(\n",
    "        messages=messages,\n",
    "        model=MODEL,\n",
    "        **args,\n",
    "    )\n",
    "\n",
    "\n",
    "def print_query_statistics(response, query_type: str) -> None:\n",
    "    print(f\"\\n{query_type} query statistics:\")\n",
    "    print(f\"\\tInput tokens: {response.usage.input_tokens}\")\n",
    "    print(f\"\\tOutput tokens: {response.usage.output_tokens}\")\n",
    "    print(f\"\\tCache read input tokens: {getattr(response.usage, 'cache_read_input_tokens', '---')}\")\n",
    "    print(\n",
    "        f\"\\tCache creation input tokens: {getattr(response.usage, 'cache_creation_input_tokens', '---')}\"\n",
    "    )"
   ]
  },
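  {
   "cell_type": "markdown",
   "id": "b7d20f11",
   "metadata": {},
   "source": [
    "The usage fields printed by `print_query_statistics` can also be checked programmatically. The next cell is one possible way to label a response, offered as a sketch: `classify_cache_usage` is a hypothetical helper, and the `SimpleNamespace` objects with made-up token counts stand in for `response.usage` so the cell runs without an API call."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7d20f12",
   "metadata": {},
   "outputs": [],
   "source": [
    "from types import SimpleNamespace\n",
    "\n",
    "\n",
    "def classify_cache_usage(usage) -> str:\n",
    "    \"\"\"Label a usage object as a cache 'hit', 'write', or 'none' (hypothetical helper).\"\"\"\n",
    "    # `or 0` also covers fields that are present but set to None.\n",
    "    read = getattr(usage, \"cache_read_input_tokens\", 0) or 0\n",
    "    created = getattr(usage, \"cache_creation_input_tokens\", 0) or 0\n",
    "    if read > 0:\n",
    "        return \"hit\"  # at least part of the prompt was served from the cache\n",
    "    if created > 0:\n",
    "        return \"write\"  # nothing was read, but a new cache entry was written\n",
    "    return \"none\"\n",
    "\n",
    "\n",
    "# Stand-ins for response.usage with hypothetical token counts.\n",
    "cold_usage = SimpleNamespace(cache_read_input_tokens=0, cache_creation_input_tokens=5000)\n",
    "warm_usage = SimpleNamespace(cache_read_input_tokens=5000, cache_creation_input_tokens=0)\n",
    "print(classify_cache_usage(cold_usage))\n",
    "print(classify_cache_usage(warm_usage))"
   ]
  },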
  {
   "cell_type": "markdown",
   "id": "7b183621",
   "metadata": {},
   "source": [
    "## Example 1: Standard Prompt Caching (Without Speculative Caching)\n",
    "\n",
    "First, let's see how standard prompt caching works. The user types their question, then we send the entire context + question to the API:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "id": "b16e9048",
   "metadata": {},
   "outputs": [],
   "source": [
    "async def standard_prompt_caching_demo():\n",
    "    client = AsyncAnthropic()\n",
    "\n",
    "    # Prepare the large context\n",
    "    initial_message = await create_initial_message()\n",
    "\n",
    "    # Simulate user typing time (in real app, this would be actual user input)\n",
    "    print(\"User is typing their question...\")\n",
    "    await asyncio.sleep(3)  # Simulate 3 seconds of typing\n",
    "    user_question = \"What is the purpose of the BtShared structure?\"\n",
    "    print(f\"User submitted: {user_question}\")\n",
    "\n",
    "    # Now send the full request (context + question)\n",
    "    full_message = copy.deepcopy(initial_message)\n",
    "    full_message[\"content\"].append(\n",
    "        {\"type\": \"text\", \"text\": f\"Answer the user's question: {user_question}\"}\n",
    "    )\n",
    "\n",
    "    print(\"\\nSending request to API...\")\n",
    "    start_time = time.time()\n",
    "\n",
    "    # Measure time to first token\n",
    "    first_token_time = None\n",
    "    async with client.messages.stream(\n",
    "        messages=[full_message],\n",
    "        model=MODEL,\n",
    "        **DEFAULT_CLIENT_ARGS,\n",
    "    ) as stream:\n",
    "        async for text in stream.text_stream:\n",
    "            if first_token_time is None and text.strip():\n",
    "                first_token_time = time.time() - start_time\n",
    "                print(f\"\\n🕐 Time to first token: {first_token_time:.2f} seconds\")\n",
    "                break\n",
    "\n",
    "        # Get the full response\n",
    "        response = await stream.get_final_message()\n",
    "\n",
    "    total_time = time.time() - start_time\n",
    "    print(f\"Total response time: {total_time:.2f} seconds\")\n",
    "    print_query_statistics(response, \"Standard Caching\")\n",
    "\n",
    "    return first_token_time, total_time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "id": "6dfa2ad1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading SQLite source files...\n",
      "Successfully downloaded btree.h\n",
      "Successfully downloaded btree.c\n",
      "Downloaded 2 files in 0.30 seconds\n",
      "User is typing their question...\n",
      "User submitted: What is the purpose of the BtShared structure?\n",
      "\n",
      "Sending request to API...\n",
      "\n",
      "🕐 Time to first token: 20.87 seconds\n",
      "Total response time: 28.32 seconds\n",
      "\n",
      "Standard Caching query statistics:\n",
      "\tInput tokens: 22\n",
      "\tOutput tokens: 362\n",
      "\tCache read input tokens: 0\n",
      "\tCache creation input tokens: 151629\n"
     ]
    }
   ],
   "source": [
    "# Run the standard demo\n",
    "standard_ttft, standard_total = await standard_prompt_caching_demo()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57d0c669",
   "metadata": {},
   "source": [
    "## Example 2: Speculative Prompt Caching\n",
    "\n",
    "Now let's see how speculative prompt caching improves TTFT by warming the cache while the user is typing:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "id": "c4ca7484",
   "metadata": {},
   "outputs": [],
   "source": [
    "async def speculative_prompt_caching_demo():\n",
    "    client = AsyncAnthropic()\n",
    "\n",
    "    # The user has a large amount of context they want to interact with,\n",
    "    # in this case it's the sqlite b-tree implementation (~150k tokens).\n",
    "    initial_message = await create_initial_message()\n",
    "\n",
    "    # Start speculative caching while user is typing\n",
    "    print(\"User is typing their question...\")\n",
    "    print(\"🔥 Starting cache warming in background...\")\n",
    "\n",
    "    # While the user is typing out their question, we sample a single token\n",
    "    # from the context the user is going to be interacting with with explicit\n",
    "    # prompt caching turned on to warm up the cache.\n",
    "    cache_task = asyncio.create_task(sample_one_token(client, [initial_message]))\n",
    "\n",
    "    # Simulate user typing time\n",
    "    await asyncio.sleep(3)  # Simulate 3 seconds of typing\n",
    "    user_question = \"What is the purpose of the BtShared structure?\"\n",
    "    print(f\"User submitted: {user_question}\")\n",
    "\n",
    "    # Ensure cache warming is complete\n",
    "    await cache_task\n",
    "    print(\"✅ Cache warming completed!\")\n",
    "\n",
    "    # Prepare messages for cached query. We make sure we\n",
    "    # reuse the same initial message as was cached to ensure we have a cache hit.\n",
    "    cached_message = copy.deepcopy(initial_message)\n",
    "    cached_message[\"content\"].append(\n",
    "        {\"type\": \"text\", \"text\": f\"Answer the user's question: {user_question}\"}\n",
    "    )\n",
    "\n",
    "    print(\"\\nSending request to API (with warm cache)...\")\n",
    "    start_time = time.time()\n",
    "\n",
    "    # Measure time to first token\n",
    "    first_token_time = None\n",
    "    async with client.messages.stream(\n",
    "        messages=[cached_message],\n",
    "        model=MODEL,\n",
    "        **DEFAULT_CLIENT_ARGS,\n",
    "    ) as stream:\n",
    "        async for text in stream.text_stream:\n",
    "            if first_token_time is None and text.strip():\n",
    "                first_token_time = time.time() - start_time\n",
    "                print(f\"\\n🚀 Time to first token: {first_token_time:.2f} seconds\")\n",
    "                break\n",
    "\n",
    "        # Get the full response\n",
    "        response = await stream.get_final_message()\n",
    "\n",
    "    total_time = time.time() - start_time\n",
    "    print(f\"Total response time: {total_time:.2f} seconds\")\n",
    "    print_query_statistics(response, \"Speculative Caching\")\n",
    "\n",
    "    return first_token_time, total_time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "id": "6c960975",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading SQLite source files...\n",
      "Successfully downloaded btree.h\n",
      "Successfully downloaded btree.c\n",
      "Downloaded 2 files in 0.36 seconds\n",
      "User is typing their question...\n",
      "🔥 Starting cache warming in background...\n",
      "User submitted: What is the purpose of the BtShared structure?\n",
      "✅ Cache warming completed!\n",
      "\n",
      "Sending request to API (with warm cache)...\n",
      "\n",
      "🚀 Time to first token: 1.94 seconds\n",
      "Total response time: 8.40 seconds\n",
      "\n",
      "Speculative Caching query statistics:\n",
      "\tInput tokens: 22\n",
      "\tOutput tokens: 330\n",
      "\tCache read input tokens: 151629\n",
      "\tCache creation input tokens: 0\n"
     ]
    }
   ],
   "source": [
    "# Run the speculative caching demo\n",
    "speculative_ttft, speculative_total = await speculative_prompt_caching_demo()"
   ]
  },
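  {
   "cell_type": "markdown",
   "id": "c9f4a2d0",
   "metadata": {},
   "source": [
    "The cache hit above depends on the real request starting with exactly the content blocks that the warm-up request sent. The next cell sketches a client-side sanity check for that property. `shares_cached_prefix` is a hypothetical helper, not part of the Anthropic SDK, and it only compares message content; it does not inspect the system prompt or model, which must also match."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9f4a2d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def shares_cached_prefix(warm_message: dict, full_message: dict) -> bool:\n",
    "    \"\"\"Return True if full_message starts with warm_message's content blocks.\"\"\"\n",
    "    warm_blocks = warm_message[\"content\"]\n",
    "    full_blocks = full_message[\"content\"]\n",
    "    return (\n",
    "        warm_message[\"role\"] == full_message[\"role\"]\n",
    "        and full_blocks[: len(warm_blocks)] == warm_blocks\n",
    "    )\n",
    "\n",
    "\n",
    "# Small stand-in messages; the real context in this notebook is far larger.\n",
    "warm_msg = {\n",
    "    \"role\": \"user\",\n",
    "    \"content\": [{\"type\": \"text\", \"text\": \"context\", \"cache_control\": {\"type\": \"ephemeral\"}}],\n",
    "}\n",
    "good_msg = {\n",
    "    \"role\": \"user\",\n",
    "    \"content\": warm_msg[\"content\"] + [{\"type\": \"text\", \"text\": \"question\"}],\n",
    "}\n",
    "bad_msg = {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": \"context (edited)\"}]}\n",
    "print(shares_cached_prefix(warm_msg, good_msg))\n",
    "print(shares_cached_prefix(warm_msg, bad_msg))"
   ]
  },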
  {
   "cell_type": "markdown",
   "id": "deed1898",
   "metadata": {},
   "source": [
    "## Performance Comparison\n",
    "\n",
    "Let's compare the results to see the benefit of speculative caching:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "id": "66b3009e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "PERFORMANCE COMPARISON\n",
      "============================================================\n",
      "\n",
      "Standard Prompt Caching:\n",
      "  Time to First Token: 20.87 seconds\n",
      "  Total Response Time: 28.32 seconds\n",
      "\n",
      "Speculative Prompt Caching:\n",
      "  Time to First Token: 1.94 seconds\n",
      "  Total Response Time: 8.40 seconds\n",
      "\n",
      "🎯 IMPROVEMENTS:\n",
      "  TTFT Improvement: 90.7% (18.93s faster)\n",
      "  Total Time Improvement: 70.4% (19.92s faster)\n"
     ]
    }
   ],
   "source": [
    "print(\"=\" * 60)\n",
    "print(\"PERFORMANCE COMPARISON\")\n",
    "print(\"=\" * 60)\n",
    "\n",
    "print(\"\\nStandard Prompt Caching:\")\n",
    "print(f\"  Time to First Token: {standard_ttft:.2f} seconds\")\n",
    "print(f\"  Total Response Time: {standard_total:.2f} seconds\")\n",
    "\n",
    "print(\"\\nSpeculative Prompt Caching:\")\n",
    "print(f\"  Time to First Token: {speculative_ttft:.2f} seconds\")\n",
    "print(f\"  Total Response Time: {speculative_total:.2f} seconds\")\n",
    "\n",
    "ttft_improvement = (standard_ttft - speculative_ttft) / standard_ttft * 100\n",
    "total_improvement = (standard_total - speculative_total) / standard_total * 100\n",
    "\n",
    "print(\"\\n🎯 IMPROVEMENTS:\")\n",
    "print(\n",
    "    f\"  TTFT Improvement: {ttft_improvement:.1f}% ({standard_ttft - speculative_ttft:.2f}s faster)\"\n",
    ")\n",
    "print(\n",
    "    f\"  Total Time Improvement: {total_improvement:.1f}% ({standard_total - speculative_total:.2f}s faster)\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d93c9bdd",
   "metadata": {},
   "source": [
    "## Key Takeaways\n",
    "\n",
    "1. **Speculative caching dramatically reduces TTFT** by warming the cache while users are typing\n",
    "2. **The pattern is most effective** with large contexts (>1000 tokens) that are reused across queries\n",
    "3. **Implementation is simple** - just send a 1-token request while the user is typing\n",
    "4. **Cache warming happens in parallel** with user input, effectively \"hiding\" the cache creation time\n",
    "\n",
    "## Best Practices\n",
    "\n",
    "- Start cache warming as early as possible (e.g., when a user focuses an input field)\n",
    "- Use exactly the same context for warming and actual requests to ensure cache hits\n",
    "- Monitor `cache_read_input_tokens` to verify cache hits\n",
    "- Add timestamps to prevent unwanted cache sharing across sessions"
   ]
  }
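  ,
  {
   "cell_type": "markdown",
   "id": "d5b803e6",
   "metadata": {},
   "source": [
    "Starting the warm-up as soon as the user focuses the input can be sketched with a small wrapper that launches the warm-up at most once and lets the submit handler wait for it. `CacheWarmer` and its method names are hypothetical and not part of the Anthropic SDK. In this notebook, `warm_up` could be `lambda: sample_one_token(client, [initial_message])`; how `start()` gets wired to a focus or keystroke event depends on your UI framework."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5b803e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import asyncio\n",
    "from collections.abc import Callable, Coroutine\n",
    "from typing import Any\n",
    "\n",
    "\n",
    "class CacheWarmer:\n",
    "    \"\"\"Start a warm-up coroutine at most once and await it later (hypothetical helper).\"\"\"\n",
    "\n",
    "    def __init__(self, warm_up: Callable[[], Coroutine[Any, Any, None]]) -> None:\n",
    "        self._warm_up = warm_up\n",
    "        self._task = None\n",
    "\n",
    "    def start(self) -> None:\n",
    "        \"\"\"Call from a focus or keystroke handler; repeat calls are no-ops.\n",
    "\n",
    "        Must be called while the event loop is running.\n",
    "        \"\"\"\n",
    "        if self._task is None:\n",
    "            self._task = asyncio.create_task(self._warm_up())\n",
    "\n",
    "    async def wait(self) -> None:\n",
    "        \"\"\"Call from the submit handler before sending the real request.\"\"\"\n",
    "        self.start()  # covers a submit that arrives before any focus event\n",
    "        await self._task"
   ]
  }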
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
